package com.stylefeng.guns.modular.prediction.util;

import com.stylefeng.guns.core.util.DoubleUtil;
import org.ujmp.core.DenseMatrix;
import org.ujmp.core.Matrix;

import java.math.BigDecimal;
import java.math.RoundingMode;
import java.text.DecimalFormat;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.IntSummaryStatistics;
import java.util.List;
import java.util.Objects;
import java.util.stream.Collectors;

import static java.util.stream.Collectors.toList;

/**
 * Multi-factor forecasting of an accident-count series.
 *
 * <p>Each influencing-factor series is first extrapolated on its own with {@code OneFactorForecasting.forecast}.
 * The accident series is then modelled as an intercept plus a weighted sum of the factor series. The weights come
 * from an ordinary-least-squares fit when the normal equations are solvable, and from a grey relational analysis
 * otherwise.
 */
public class MultiFactorForecasting {

	/**
	 * Magnitude below which det(X^T X) is treated as zero.
	 *
	 * <p>X holds only integers, so det(X^T X) is mathematically an integer. A genuinely non-singular matrix
	 * therefore has |det| >= 1, and anything smaller is a zero determinant polluted by floating-point rounding.
	 */
	private static final double SINGULAR_DET_THRESHOLD = 0.5;

	/**
	 * Forecasts the accident series from its own history and the history of k influencing factors.
	 *
	 * @param accidentActualList observed accident counts, one entry per period; must not be empty
	 * @param factorActualLists  observed values of each factor, one list per factor; must not be empty, and every
	 *                           list must have the same length as {@code accidentActualList}
	 * @param after              number of additional periods to produce beyond the observed history
	 * @return {@code accidentActualList.size() + after} values (fitted history followed by forecasts), each rounded
	 *         to two decimal places
	 * @throws IllegalArgumentException if an input list is empty or the factor lengths do not match the accident
	 *                                  series
	 * @throws NullPointerException     if an input list is {@code null}
	 */
	public static List<Double> forecast(List<Integer> accidentActualList, List<List<Integer>> factorActualLists,
			int after) {
		validateInputs(accidentActualList, factorActualLists);

		List<List<Double>> factorForecastLists = forecastFactors(factorActualLists);
		// Layout: coefficients.get(0) is the intercept, coefficients.get(j + 1) is the weight of factor j.
		List<Double> coefficients = calculateCoefficient(accidentActualList, factorActualLists);

		int forecastSize = accidentActualList.size() + after;
		List<Double> forecasts = new ArrayList<>(forecastSize);

		for (int i = 0; i < forecastSize; i++) {
			double forecast = coefficients.get(0);
			for (int j = 0; j < factorForecastLists.size(); j++) {
				// NOTE(review): assumes OneFactorForecasting.forecast yields at least forecastSize values - confirm.
				forecast = DoubleUtil.add(forecast,
						DoubleUtil.mul(factorForecastLists.get(j).get(i), coefficients.get(j + 1)));
			}
			forecasts.add(forecast);
		}

		return reserveTwoDecimal(forecasts);
	}

	/**
	 * Rejects inputs from which no model can be fitted.
	 *
	 * <p>Both coefficient methods pair factor value i with accident value i, so all series must cover the same
	 * periods.
	 */
	private static void validateInputs(List<Integer> accidentActualList, List<List<Integer>> factorActualLists) {
		Objects.requireNonNull(accidentActualList, "accidentActualList");
		Objects.requireNonNull(factorActualLists, "factorActualLists");
		if (accidentActualList.isEmpty()) {
			throw new IllegalArgumentException("accidentActualList must not be empty");
		}
		if (factorActualLists.isEmpty()) {
			throw new IllegalArgumentException("factorActualLists must not be empty");
		}
		int periods = accidentActualList.size();
		for (int j = 0; j < factorActualLists.size(); j++) {
			List<Integer> factorList = Objects.requireNonNull(factorActualLists.get(j),
					"factorActualLists[" + j + "]");
			if (factorList.size() != periods) {
				throw new IllegalArgumentException("factorActualLists[" + j + "] has " + factorList.size()
						+ " values but accidentActualList has " + periods);
			}
		}
	}

	/**
	 * Rounds every value to two decimal places using HALF_EVEN rounding (the default of {@code DecimalFormat}).
	 *
	 * <p>The arithmetic is locale-independent; formatting and reparsing a string would break wherever the
	 * default locale uses ',' as the decimal separator. Non-finite values (NaN, infinities) are returned
	 * unchanged.
	 */
	private static List<Double> reserveTwoDecimal(List<Double> forecastResults) {
		return forecastResults.stream().map(MultiFactorForecasting::roundToTwoDecimals).collect(Collectors.toList());
	}

	/** Rounds one value to two decimals. {@code new BigDecimal(double)} keeps the exact binary value. */
	private static Double roundToTwoDecimals(Double value) {
		if (!Double.isFinite(value)) {
			return value;
		}
		return new BigDecimal(value).setScale(2, RoundingMode.HALF_EVEN).doubleValue();
	}

	/** Extrapolates every factor series independently with {@code OneFactorForecasting.forecast}. */
	private static List<List<Double>> forecastFactors(List<List<Integer>> factorActualLists) {
		return factorActualLists.stream().map(OneFactorForecasting::forecast).collect(toList());
	}

	/**
	 * Calculates the combination coefficients.
	 *
	 * <p>Solves the least-squares normal equations {@code b = (X^T X)^-1 X^T y}. The first column of X is all
	 * ones (intercept) and column j + 1 holds factor j. When {@code X^T X} is singular, either because there are
	 * fewer periods than unknowns or because its determinant is numerically zero, the weights from
	 * {@link #calculateCoefficient2} are used with an intercept of 0.
	 *
	 * @param accidentActualList observed accident counts (y)
	 * @param factorActualLists  observed factor values, each aligned with {@code accidentActualList}
	 * @return k + 1 values: the intercept followed by one weight per factor
	 */
	private static List<Double> calculateCoefficient(List<Integer> accidentActualList,
			List<List<Integer>> factorActualLists) {
		int factorCount = factorActualLists.size();
		int xColumns = factorCount + 1;
		int xRows = accidentActualList.size();

		Matrix xMatrix = DenseMatrix.Factory.zeros(xRows, xColumns);
		for (int row = 0; row < xRows; row++) {
			xMatrix.setAsInt(1, row, 0);
		}
		for (int col = 1; col <= factorCount; col++) {
			List<Integer> factorList = factorActualLists.get(col - 1);
			for (int row = 0; row < xRows; row++) {
				xMatrix.setAsInt(factorList.get(row), row, col);
			}
		}

		Matrix yMatrix = DenseMatrix.Factory.zeros(xRows, 1);
		for (int row = 0; row < xRows; row++) {
			yMatrix.setAsInt(accidentActualList.get(row), row, 0);
		}

		Matrix xT = xMatrix.transpose();
		Matrix xT_X = xT.mtimes(xMatrix);

		// With fewer rows than columns X^T X is always rank-deficient; otherwise use the integer-based threshold
		// rather than an exact floating-point comparison against 0.
		if (xRows < xColumns || Math.abs(xT_X.det()) < SINGULAR_DET_THRESHOLD) {
			List<Double> coefficients = new ArrayList<>(xColumns);
			coefficients.add(0d);
			coefficients.addAll(calculateCoefficient2(accidentActualList, factorActualLists));
			return coefficients;
		}

		Matrix xT_X_INV_XT = xT_X.inv().mtimes(xT);
		Matrix solution = xT_X_INV_XT.mtimes(yMatrix);

		List<Double> coefficients = new ArrayList<>(xColumns);
		for (int i = 0; i < xColumns; i++) {
			coefficients.add(solution.getAsDouble(i, 0));
		}
		return coefficients;
	}

	/**
	 * Calculates one weight per factor using grey relational analysis.
	 *
	 * <p>For period i and factor j the deviation is Δ = |x_j(i) - y(i)|. With the global extremes Δmin and Δmax
	 * over all factors and periods, and distinguishing coefficient p = 0.5, the relational coefficient is
	 * (Δmin + p·Δmax) / (Δ + p·Δmax). A factor's grade is the mean of its coefficients, and the grades are
	 * normalised to sum to 1.
	 *
	 * @param accidentActualList observed accident counts
	 * @param factorActualLists  observed factor values, each aligned with {@code accidentActualList}
	 * @return k weights, one per factor
	 */
	private static List<Double> calculateCoefficient2(List<Integer> accidentActualList,
			List<List<Integer>> factorActualLists) {
		List<List<Integer>> absDiffLists = new ArrayList<>(factorActualLists.size());
		for (List<Integer> factorActual : factorActualLists) {
			List<Integer> absDiffList = new ArrayList<>(factorActual.size());
			for (int i = 0; i < factorActual.size(); i++) {
				absDiffList.add(Math.abs(factorActual.get(i) - accidentActualList.get(i)));
			}
			absDiffLists.add(absDiffList);
		}

		IntSummaryStatistics intSummaryStatistics = absDiffLists.stream().flatMap(List::stream).mapToInt(i -> i)
				.summaryStatistics();

		int maxmax = intSummaryStatistics.getMax();
		int minmin = intSummaryStatistics.getMin();

		double p = 0.5;
		double pMax = DoubleUtil.mul(p, maxmax);
		double numerator = DoubleUtil.add(minmin, pMax);

		// If maxmax is 0, every factor reproduces the accident series exactly. Each coefficient is then 1 by
		// definition, whereas the formula would evaluate 0 / 0.
		List<Double> factorAvgCorrelationList = absDiffLists.stream()
				.map(absDiffList -> absDiffList.stream()
						.mapToDouble(k -> maxmax == 0 ? 1d : DoubleUtil.div(numerator, DoubleUtil.add(k, pMax)))
						.average().getAsDouble())
				.collect(toList());

		double avgCorrelationSum = factorAvgCorrelationList.stream().mapToDouble(d -> d).sum();

		return factorAvgCorrelationList.stream().map(avg -> DoubleUtil.div(avg, avgCorrelationSum))
				.collect(toList());
	}


	/**
	 * Demo with hard-coded sample data.
	 *
	 * <p>It uses 4 periods and 4 factors, so there are fewer observations than the 5 unknowns and the
	 * grey-relational weights are used.
	 */
	public static void main(String[] args) {
		List<Integer> accidentList = Arrays.asList(6, 5, 5, 0);
		List<Integer> factor1List = Arrays.asList(1, 1, 2, 1);
		List<Integer> factor2List = Arrays.asList(17, 13, 9, 8);
		List<Integer> factor3List = Arrays.asList(7, 12, 14, 10);
		List<Integer> factor4List = Arrays.asList(2, 1, 1, 1);
		List<List<Integer>> factorLists = Arrays.asList(factor1List, factor2List, factor3List, factor4List);

		List<Double> forecast = forecast(accidentList, factorLists, 1);
		System.out.println(forecast);
	}


}